# (논문) Chest X-ray
# Kaggle chest X-ray 데이터를 활용한 CNN 및 CAM
import torch
from fastai.vision.all import *
import cv2
import fastbook
from fastbook import *
from fastai.vision.widgets import *
# Root directory of the chest X-ray dataset on the local disk.
path = Path('/home/khy/chest_xray/chest_xray')
path.ls()  # notebook display: entries directly under the dataset root

# Every image file found under the dataset root.
files = get_image_files(path)
files  # notebook display

# Build the data loaders, holding out a random 20% of the items for
# validation and resizing every image to 224 px.  The seed pins the random
# split so that metrics and the CAM figures are reproducible between runs.
# NOTE(review): when `valid_pct` is given, fastai appears to ignore `train`
# and to split every image found under `path` - confirm this is intended.
dls = ImageDataLoaders.from_folder(
    path,
    train='train',
    valid_pct=0.2,
    seed=42,
    item_tfms=Resize(224),
)
dls.vocab  # notebook display: class labels in index order
dls.show_batch(max_n=16)
# A ResNet-34 learner is created only to borrow its convolutional body.
# NOTE(review): cnn_learner presumably also attaches ImageNet normalisation to
# `dls` as a side effect - confirm before replacing this call.
learn = cnn_learner(dls, resnet34, metrics=error_rate)
net1 = learn.model[0]  # first half of the fastai model (feature extractor)

# The stock fastai head (learn.model[1]) is not used.  CAM needs the class
# scores to be a linear function of the globally averaged feature maps, so a
# minimal head is built instead: global average pooling, flatten, and one
# bias-free linear layer mapping 512 features to 2 class scores.
net2 = torch.nn.Sequential(
    torch.nn.AdaptiveAvgPool2d(output_size=1),
    torch.nn.Flatten(),
    torch.nn.Linear(512, out_features=2, bias=False),
)
net = torch.nn.Sequential(net1, net2)

# Train body and new head together.
lrnr2 = Learner(dls, net, metrics=accuracy)
lrnr2.fine_tune(200)
def _show_cam_grid(image_files, start, n_rows=5, n_cols=5):
    """Plot a grid of consecutive images with their predicted-class CAM.

    For each of the ``n_rows * n_cols`` files starting at ``image_files[start]``
    the image is pushed through the network, the class activation map of the
    more probable class is overlaid on the decoded image, and the panel is
    titled with that class and its softmax probability.

    Args:
        image_files: indexable collection of image paths.
        start: index of the first image to show.
        n_rows: number of grid rows.
        n_cols: number of grid columns.

    Returns:
        The matplotlib figure holding the grid.

    Raises:
        IndexError: if ``image_files`` has fewer than
            ``start + n_rows * n_cols`` entries.
    """
    # squeeze=False keeps `axes` two-dimensional even for a 1x1 grid.
    fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
    head_weight = net2[2].weight
    net.eval()  # inference: use running batch-norm statistics
    # `axes.flat` walks the grid row by row.
    for offset, axis in enumerate(axes.flat):
        img = PILImage.create(image_files[start + offset])
        x, = first(dls.test_dl([img]))
        with torch.no_grad():
            # One pass through the body serves both the CAM and the logits.
            features = net1(x)
            cam = torch.einsum('ij,jkl -> ikl', head_weight, features.squeeze())
            logits = net2(features)
        # torch.softmax is numerically stable, unlike exp(a) / (exp(a) + exp(b)).
        normalprob, pneumoniaprob = torch.softmax(
            logits.detach().cpu().double(), dim=1
        )[0].tolist()
        if normalprob > pneumoniaprob:
            cls, title = 0, "normal(%s)" % round(normalprob, 5)
        else:
            cls, title = 1, "pneumonia(%s)" % round(pneumoniaprob, 5)
        dls.train.decode((x,))[0].squeeze().show(ax=axis)
        axis.imshow(
            cam[cls].to("cpu").detach(),
            alpha=0.5,
            extent=(0, 224, 224, 0),
            interpolation='bilinear',
            cmap='magma',
        )
        axis.set_title(title)
    fig.set_figwidth(16)
    fig.set_figheight(16)
    fig.tight_layout()
    return fig


# Scan the dataset directory once instead of once per plotted image.
_cam_grid_files = get_image_files(path)
fig = _show_cam_grid(_cam_grid_files, start=0)
fig = _show_cam_grid(_cam_grid_files, start=3000)
# Pick the single image that the rest of the analysis works on.
_sample_path = get_image_files(path)[3021]
_sample_path  # notebook display
img = PILImage.create(_sample_path)
img  # notebook display

x, = first(dls.test_dl([img]))  # turn the image into a model-ready batch tensor
x.shape

# Everything from here on runs on the CPU.  Move the whole model there first:
# after training it may still sit on the accelerator, and feeding it a CPU
# tensor would then fail with a device-mismatch error.
net.to('cpu')
with torch.no_grad():
    # One pass through the body serves both the logits and the CAM.
    _features = net1(x.to('cpu'))
    _logits = net2(_features)
    # Class activation maps: head weights applied to every spatial position
    # of the feature maps, giving one map per output class.
    camimg = torch.einsum('ij,jkl -> ikl', net2[2].weight, _features.squeeze())

a, b = _logits.tolist()[0]
# Softmax probabilities of the two classes (numerically stable form).
tuple(torch.softmax(_logits.double(), dim=1)[0].tolist())
# Side-by-side view of the two class activation maps, each drawn half
# transparent on top of the decoded input image.
fig, (ax1, ax2) = plt.subplots(1, 2)
for panel, cam_idx, caption in ((ax1, 0, "NORMAL PART"), (ax2, 1, "DISEASE PART")):
    dls.train.decode((x,))[0].squeeze().show(ax=panel)
    panel.imshow(
        camimg[cam_idx].detach().to("cpu"),
        alpha=0.5,
        extent=(0, 224, 224, 0),
        interpolation='bilinear',
        cmap='cool',
    )
    panel.set_title(caption)
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
# Shift the second-class CAM so that its minimum is zero, then pass it through
# an exponential decay: A1 equals 1 where the shifted CAM is 0 and shrinks as
# the activation grows; A2 is its complement.
test = camimg[1] - camimg[1].min()
A1 = (-0.02 * test).exp()
A2 = 1 - A1

# Show both weight maps over the decoded input image.
fig, (ax1, ax2) = plt.subplots(1, 2)
for panel, weight, caption in (
    (ax1, A2, "MODE1 WEIGHT"),
    (ax2, A1, "MODE1 RES WEIGHT"),
):
    dls.train.decode((x,))[0].squeeze().show(ax=panel)
    panel.imshow(
        weight.data.to("cpu").detach(),
        alpha=0.5,
        extent=(0, 224, 224, 0),
        interpolation='bilinear',
        cmap='cool',
    )
    panel.set_title(caption)
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
# The input image without its batch dimension, on the CPU.
_x_cpu = x.squeeze().to('cpu')

# Residual image: resize the A1 weight map to 224x224 with bilinear
# interpolation, multiply it into the image, and shift the result so that its
# minimum is zero.
X1 = np.array(A1.detach().to("cpu"), dtype=np.float32)
Y1 = torch.Tensor(cv2.resize(X1, (224, 224), interpolation=cv2.INTER_LINEAR))
_x_times_y1 = _x_cpu * Y1
x1 = _x_times_y1 - torch.min(_x_times_y1)

# Mode image: same procedure with the complementary weight A2; unlike x1, no
# minimum shift is applied.
X12 = np.array(A2.detach().to("cpu"), dtype=np.float32)
Y12 = torch.Tensor(cv2.resize(X12, (224, 224), interpolation=cv2.INTER_LINEAR))
x12 = _x_cpu * Y12
# - 1st CAM 결과를 기준으로 이미지를 분리하면 아래와 같음 (MODE1 / MODE1 RES).
# Reference panel: the decoded input image on its own.
fig, ax1 = plt.subplots(1, 1)
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.set_title("ORIGINAL")
fig.set_figwidth(4)
fig.set_figheight(4)
fig.tight_layout()

# The two weighted images, each scaled by a display factor before plotting.
fig, (ax1, ax2) = plt.subplots(1, 2)
for panel, picture, caption in (
    (ax1, x12 * 0.3, "MODE1"),
    (ax2, x1 * 0.2, "MODE1 RES"),
):
    picture.squeeze().show(ax=panel)
    panel.set_title(caption)
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
# Give the residual image a batch dimension so the network accepts it.
x1 = x1.reshape(1, 3, 224, 224)

# Both halves of the model are moved to the CPU before running on x1.
for _part in (net1, net2):
    _part.to('cpu')

# Second-round CAM: head weights applied to the feature maps of the residual
# image.
_features1 = net1(x1)
camimg1 = torch.einsum('ij,jkl -> ikl', net2[2].weight, _features1.squeeze())
# CAM
# - MODE1 RES 이미지에 CAM 결과 올리기
# Both second-round activation maps drawn over the scaled residual image.
fig, (ax1, ax2) = plt.subplots(1, 2)
for panel, cam_idx, caption in ((ax1, 0, "NORMAL PART"), (ax2, 1, "DISEASE PART")):
    (x1 * 0.2).squeeze().show(ax=panel)
    panel.imshow(
        camimg1[cam_idx].detach().to("cpu"),
        alpha=0.5,
        extent=(0, 223, 223, 0),
        interpolation='bilinear',
        cmap='cool',
    )
    panel.set_title(caption)
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
# Compare the second-class map of the first and second CAM rounds, both drawn
# over the decoded input image.
fig, (ax1, ax2) = plt.subplots(1, 2)
for panel, cam_stack, caption in ((ax1, camimg, "1ST CAM"), (ax2, camimg1, "2ND CAM")):
    dls.train.decode((x,))[0].squeeze().show(ax=panel)
    panel.imshow(
        cam_stack[1].detach().to("cpu"),
        alpha=0.5,
        extent=(0, 223, 223, 0),
        interpolation='bilinear',
        cmap='cool',
    )
    panel.set_title(caption)
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
x1.shape  # notebook display

# A single forward pass yields both logits; no gradients are needed.
with torch.no_grad():
    _logits1 = net(x1)
a1, b1 = _logits1.tolist()[0]
# Softmax probabilities of the two classes.  torch.softmax is numerically
# stable, whereas exp(a) / (exp(a) + exp(b)) can overflow for large logits.
tuple(torch.softmax(_logits1.double(), dim=1)[0].tolist())
# Shift the second-round, second-class CAM so that its minimum is zero, then
# pass it through an exponential decay: A3 equals 1 where the shifted CAM is 0
# and shrinks as the activation grows; A4 is its complement.
test1 = camimg1[1] - torch.min(camimg1[1])
A3 = torch.exp(-0.04 * test1)
A4 = 1 - A3

# A4 is the weight that builds the MODE2 image and A3 the one that builds
# MODE2 RES, so they are labelled accordingly (left: mode weight, right:
# residual weight, the same layout as the MODE1 weight figure).
fig, (ax1, ax2) = plt.subplots(1, 2)
for panel, weight, caption in (
    (ax1, A4, "MODE2 WEIGHT"),
    (ax2, A3, "MODE2 RES WEIGHT"),
):
    dls.train.decode((x,))[0].squeeze().show(ax=panel)
    panel.imshow(
        weight.data.to("cpu").detach(),
        alpha=0.5,
        extent=(0, 224, 224, 0),
        interpolation='bilinear',
        cmap='binary',
    )
    panel.set_title(caption)
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
# The input image without its batch dimension, on the CPU.
_x_cpu2 = x.squeeze().to('cpu')

# Second residual image: the input weighted by both residual maps (Y1 from the
# first round, Y3 from this one), shifted so that its minimum is zero.  It
# mirrors how x1 was built from Y1 alone.
X3 = np.array(A3.detach().to("cpu"), dtype=np.float32)
Y3 = torch.Tensor(cv2.resize(X3, (224, 224), interpolation=cv2.INTER_LINEAR))
_x_times_y1_y3 = _x_cpu2 * Y1 * Y3
x3 = _x_times_y1_y3 - torch.min(_x_times_y1_y3)

# Second mode image: the input weighted by both mode maps (Y12 and Y4), with
# no minimum shift.  It mirrors how x12 was built from Y12 alone.
X4 = np.array(A4.detach().to("cpu"), dtype=np.float32)
Y4 = torch.Tensor(cv2.resize(X4, (224, 224), interpolation=cv2.INTER_LINEAR))
x4 = _x_cpu2 * Y12 * Y4
# - 2nd CAM 결과를 기준으로 이미지를 분리하면 아래와 같음 (MODE2 / MODE2 RES).
# Reference panel: the decoded input image on its own.
fig, ax1 = plt.subplots(1, 1)
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.set_title("ORIGINAL")
fig.set_figwidth(4)
fig.set_figheight(4)
fig.tight_layout()

# One figure per decomposition round; each image is multiplied by its own
# display factor before plotting.
for (left_img, left_caption), (right_img, right_caption) in (
    ((x12 * 0.5, "MODE1"), (x1 * 0.3, "MODE1 RES")),
    ((x4 * 3, "MODE2"), (x3 * 0.3, "MODE2 RES")),
):
    fig, (ax1, ax2) = plt.subplots(1, 2)
    left_img.squeeze().show(ax=ax1)
    right_img.squeeze().show(ax=ax2)
    ax1.set_title(left_caption)
    ax2.set_title(right_caption)
    fig.set_figwidth(8)
    fig.set_figheight(8)
    fig.tight_layout()
# Give the second residual image a batch dimension so the network accepts it.
x3 = x3.reshape(1, 3, 224, 224)

# Both halves of the model are moved to the CPU before running on x3.
for _part in (net1, net2):
    _part.to('cpu')

# Third-round CAM: head weights applied to the feature maps of the second
# residual image.
_features2 = net1(x3)
camimg2 = torch.einsum('ij,jkl -> ikl', net2[2].weight, _features2.squeeze())
# - CAM: MODE2 RES 이미지에 CAM 결과 올리기
# Both third-round activation maps drawn over the scaled second residual image.
fig, (ax1, ax2) = plt.subplots(1, 2)
for panel, cam_idx, caption in ((ax1, 0, "NORMAL PART"), (ax2, 1, "DISEASE PART")):
    (x3 * 0.3).squeeze().show(ax=panel)
    panel.imshow(
        camimg2[cam_idx].detach().to("cpu"),
        alpha=0.5,
        extent=(0, 223, 223, 0),
        interpolation='bilinear',
        cmap='cool',
    )
    panel.set_title(caption)
fig.set_figwidth(8)
fig.set_figheight(8)
fig.tight_layout()
# A single forward pass yields both logits; no gradients are needed.
with torch.no_grad():
    _logits2 = net(x3)
a2, b2 = _logits2.tolist()[0]
# Softmax probabilities of the two classes.  torch.softmax is numerically
# stable, whereas exp(a) / (exp(a) + exp(b)) can overflow for large logits.
tuple(torch.softmax(_logits2.double(), dim=1)[0].tolist())